// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <string>
#include <string_view>
#include <tuple>
#include <vector>

#include "base/command_line.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/test/scoped_feature_list.h"
#include "base/version_info/channel.h"
#include "build/build_config.h"
#include "chrome/browser/extensions/extension_browsertest.h"
#include "components/embedder_support/switches.h"
#include "content/public/common/content_switches.h"
#include "content/public/test/browser_test.h"
#include "extensions/test/result_catcher.h"
#include "extensions/test/test_extension_dir.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features_generated.h"

// TODO(crbug.com/350642260): The Prompt API origin trial for extensions does
// not take effect on ChromeOS. That logic is skipped on ChromeOS, so this test
// is skipped there as well.
namespace extensions {

namespace {

// This is the public key of tools/origin_trials/eftest.key, used to validate
// origin trial tokens generated by tools/origin_trials/generate_token.py.
constexpr char kOriginTrialPublicKeyForTesting[] =
    "dRCs+TocuKkocNKa0AtZ4awrt9XKH2SQCI6o4FY6BNA=";

// Manifest fragment (including its trailing comma) that requests the
// `aiLanguageModelOriginTrial` permission. Spliced into the manifest template
// only when the test variant requests the permission.
constexpr char kAILanguageModelOriginTrialPermissionsField[] =
    "\"permissions\":[\"aiLanguageModelOriginTrial\"],";

// Manifest fragment (including its trailing comma) carrying an origin trial
// token for the `AIPromptAPIForExtension` trial, bound to the test extension's
// origin. The token (which expires on 2032-11-26) was generated by
// ```
// tools/origin_trials/generate_token.py
// chrome-extension://jnapclmfkaejhjkddbmiafekigmcbmma AIPromptAPIForExtension
// --expire-days 3000
// ```
constexpr char kLanguageModelOriginTrialTokensField[] =
    "\"trial_tokens\":[\"A5nDxhrF7Qe4GiLouR1mgL5XKSk4wXA0B/RV2VyQcZj2IkLALdG/"
    "FHrucKbG1TKD8QidNfqBdC07wP8KJaF6EQYAAAB9eyJvcmlnaW4iOiAiY2hyb21lLWV4dGVuc2"
    "lvbjovL2puYXBjbG1ma2Flamhqa2RkYm1pYWZla2lnbWNibW1hIiwgImZlYXR1cmUiOiAiQUlQ"
    "cm9tcHRBUElGb3JFeHRlbnNpb24iLCAiZXhwaXJ5IjogMTk4NTA2MjMwM30=\"],";

// Manifest template for the test extension, filled in with
// base::StringPrintf():
//   1st `%s`: the permissions fragment, or "" when not requested;
//   2nd `%s`: the trial tokens fragment, or "" when not participating.
// Each fragment carries its own trailing comma so an empty substitution still
// yields valid JSON.
// The `key` field stores the public key for the extension with id
// "jnapclmfkaejhjkddbmiafekigmcbmma", the same origin the trial token is
// issued for.
static constexpr char kManifestTemplate[] =
    R"JS(
    {
      "name": "AI assistant test",
      "version": "0.1",
      "manifest_version": 3,
      %s
      %s
      "key": "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3H6Jc0On6l0H3DJ6bx4aOW3+srCfjSdr+3ukwIEZrL6jDy500XweIwOp9PItpM9sijwu8v1rdyoBPubm/ottp/oz42aKp+2xIxcMTa6/cA2BL2kOWxwv+WP9d01IOFbFpWmQBDQNpp2UmH67OFbie6zHhyrSJKL2o9d05iX0a9Xwv9W48JKYpldo+/2JTP/5en0jxgiN+qkOCZuLag2cS/6Az0LArqsf5D+ReJemIBCNJhVxu3P0naxfEG6B6XczzuuptrX3H2vDr1LxZasLh9bzV88+8BxarjETACebOfqy366QxXluwAjnu/NHPv53edXlXvXrZ0C69RvvlMh1qQIDAQAB",
      "description": "Extension for testing the AI assistant API.",
      "background": {
        "service_worker": "sw.js"
      }
    }
  )JS";

// Service worker script for the test extension, filled in with
// base::StringPrintf(). Each `%s` must be the JS literal `true` or `false`:
//   1st `%s`: whether `self.ai` is expected to be defined;
//   2nd `%s`: whether `chrome.aiOriginTrial.languageModel` is expected to be
//             defined.
// Each check reports through chrome.test so the C++ side can collect the
// results with a ResultCatcher.
static constexpr char kServiceWorkerScript[] =
    R"JS(
      chrome.test.runTests([
        function verifySelfAi() {
          const expectSelfAi = %s;
          chrome.test.assertEq(expectSelfAi, !!self.ai);
          chrome.test.succeed();
        },
        function verifyChromeAiOriginTrial() {
          const expectChromeAiOriginTrialAssistant = %s;
          chrome.test.assertEq(
            expectChromeAiOriginTrialAssistant,
            !!(chrome.aiOriginTrial && chrome.aiOriginTrial.languageModel)
          );
          chrome.test.succeed();
        },
      ]);
    )JS";

// The boolean tuple describing one test variant, in this order:
// 1. whether the `AIPromptAPIForWebPlatform` Blink feature is enabled (vs.
//    disabled) on the command line;
// 2. whether the `AIPromptAPIForExtension` Blink feature is enabled (vs.
//    disabled) on the command line;
// 3. whether the kill switch is triggered, i.e. the base feature
//    `kEnableAIPromptAPIForExtension` is disabled;
// 4. whether the extension manifest requests the `aiLanguageModelOriginTrial`
//    permission;
// 5. whether the extension manifest carries the origin trial token.
using ExtensionAIAssistantBrowserTestVariant =
    std::tuple<bool, bool, bool, bool, bool>;

// Tuple positions of each dimension of ExtensionAIAssistantBrowserTestVariant,
// matching the order documented on the alias.
constexpr size_t kWebPlatformFlagIndex = 0;
constexpr size_t kExtensionFlagIndex = 1;
constexpr size_t kKillSwitchIndex = 2;
constexpr size_t kPermissionIndex = 3;
constexpr size_t kOriginTrialIndex = 4;

// True when the variant enables the `AIPromptAPIForWebPlatform` Blink feature.
bool IsPromptAPIForWebPlatformEnabled(
    ExtensionAIAssistantBrowserTestVariant param) {
  return std::get<kWebPlatformFlagIndex>(param);
}

// True when the variant enables the `AIPromptAPIForExtension` Blink feature.
bool IsPromptAPIForExtensionEnabled(
    ExtensionAIAssistantBrowserTestVariant param) {
  return std::get<kExtensionFlagIndex>(param);
}

// True when the variant disables the extension Prompt API base feature.
bool IsPromptAPIForExtensionKillSwitchTriggered(
    ExtensionAIAssistantBrowserTestVariant param) {
  return std::get<kKillSwitchIndex>(param);
}

// True when the variant's manifest requests the language model permission.
bool IsExtensionPermissionRequested(
    ExtensionAIAssistantBrowserTestVariant param) {
  return std::get<kPermissionIndex>(param);
}

// True when the variant's manifest includes the origin trial token.
bool IsExtensionParticipatingInOriginTrial(
    ExtensionAIAssistantBrowserTestVariant param) {
  return std::get<kOriginTrialIndex>(param);
}

// Builds a readable name for a parameterized test instance from its variant,
// made of alphanumeric segments joined by underscores, e.g.
// "PromptAPIEnabledForWebPlatform_PromptAPIDisabledForExtension_"
// "ExtensionPermissionRequested_ExtensionNotParticipatingInOriginTrial_"
// "KillSwitchNotTriggered".
std::string DescribeTestVariant(
    const testing::TestParamInfo<ExtensionAIAssistantBrowserTestVariant>&
        info) {
  const ExtensionAIAssistantBrowserTestVariant& param = info.param;
  const std::string web_platform_feature = base::StringPrintf(
      "PromptAPI%sForWebPlatform",
      IsPromptAPIForWebPlatformEnabled(param) ? "Enabled" : "Disabled");
  const std::string extension_feature = base::StringPrintf(
      "PromptAPI%sForExtension",
      IsPromptAPIForExtensionEnabled(param) ? "Enabled" : "Disabled");
  const std::string extension_permission = base::StringPrintf(
      "ExtensionPermission%sRequested",
      IsExtensionPermissionRequested(param) ? "" : "Not");
  const std::string extension_origin_trial = base::StringPrintf(
      "Extension%sParticipatingInOriginTrial",
      IsExtensionParticipatingInOriginTrial(param) ? "" : "Not");
  const std::string kill_switch = base::StringPrintf(
      "KillSwitch%s", IsPromptAPIForExtensionKillSwitchTriggered(param)
                          ? "Triggered"
                          : "NotTriggered");

  return base::JoinString({web_platform_feature, extension_feature,
                           extension_permission, extension_origin_trial,
                           kill_switch},
                          "_");
}

}  // namespace

// Parameterized browser test that loads an extension configured according to
// the test variant and checks which Prompt API entry points it can see.
class ExtensionAIAssistantBrowserTest
    : public ExtensionBrowserTest,
      public testing::WithParamInterface<
          ExtensionAIAssistantBrowserTestVariant> {
 public:
  // Translates the variant into Blink feature switches, the origin trial test
  // key, and base feature overrides (including the extension kill switch).
  void SetUpCommandLine(base::CommandLine* command_line) override {
    ExtensionBrowserTest::SetUpCommandLine(command_line);

    std::vector<std::string_view> enabled_features;
    // Disable all the other AI APIs to avoid unexpected ai namespace.
    std::vector<std::string_view> disabled_features{
        "AISummarizationAPI", "AIWriterAPI", "AIRewriterAPI"};

    if (IsPromptAPIForWebPlatformEnabled(GetParam())) {
      enabled_features.push_back("AIPromptAPIForWebPlatform");
    } else {
      disabled_features.push_back("AIPromptAPIForWebPlatform");
    }
    if (IsPromptAPIForExtensionEnabled(GetParam())) {
      enabled_features.push_back("AIPromptAPIForExtension");
    } else {
      disabled_features.push_back("AIPromptAPIForExtension");
    }

    // base::CommandLine keeps only the most recently appended value for a
    // given switch, so appending kEnableBlinkFeatures / kDisableBlinkFeatures
    // once per feature would leave only the last feature in effect. Pass each
    // list as a single comma-separated value instead.
    if (!enabled_features.empty()) {
      command_line->AppendSwitchASCII(switches::kEnableBlinkFeatures,
                                      base::JoinString(enabled_features, ","));
    }
    // `disabled_features` always holds at least the three AI APIs above.
    command_line->AppendSwitchASCII(switches::kDisableBlinkFeatures,
                                    base::JoinString(disabled_features, ","));

    // Also specify the test public key to make the test token effective.
    command_line->AppendSwitchASCII(embedder_support::kOriginTrialPublicKey,
                                    kOriginTrialPublicKeyForTesting);

    // The base feature for the web platform prompt API should be enabled so we
    // don't apply the kill switch to it.
    std::vector<base::test::FeatureRefAndParams> enabled_base_features{
        {blink::features::kEnableAIPromptAPIForWebPlatform, {}}};
    std::vector<base::test::FeatureRef> disabled_base_features;
    if (IsPromptAPIForExtensionKillSwitchTriggered(GetParam())) {
      disabled_base_features.push_back(
          {blink::features::kEnableAIPromptAPIForExtension});
    } else {
      enabled_base_features.push_back(
          {blink::features::kEnableAIPromptAPIForExtension, {}});
    }
    feature_list_.InitWithFeaturesAndParameters(enabled_base_features,
                                                disabled_base_features);
  }

 protected:
  // Returns the extension manifest for the current variant, optionally
  // including the permission request and the origin trial token.
  std::string GetManifest() {
    return base::StringPrintf(kManifestTemplate,
                              IsExtensionPermissionRequested(GetParam())
                                  ? kAILanguageModelOriginTrialPermissionsField
                                  : "",
                              IsExtensionParticipatingInOriginTrial(GetParam())
                                  ? kLanguageModelOriginTrialTokensField
                                  : "");
  }

 private:
  base::test::ScopedFeatureList feature_list_;
};

// Runs the test for every combination of the five boolean dimensions of
// ExtensionAIAssistantBrowserTestVariant (2^5 = 32 instances), naming each
// instance with DescribeTestVariant().
INSTANTIATE_TEST_SUITE_P(
    /* no prefix */,
    ExtensionAIAssistantBrowserTest,
    testing::Combine(testing::Bool(),
                     testing::Bool(),
                     testing::Bool(),
                     testing::Bool(),
                     testing::Bool()),
    &DescribeTestVariant);

#if BUILDFLAG(IS_CHROMEOS)
#define MAYBE_TestAssistantFactoryExistence \
  DISABLED_TestAssistantFactoryExistence
#else
#define MAYBE_TestAssistantFactoryExistence TestAssistantFactoryExistence
#endif  // BUILDFLAG(IS_CHROMEOS)
// Loads an extension configured for the current variant and lets its service
// worker assert whether `self.ai` and `chrome.aiOriginTrial.languageModel` are
// exposed as expected.
IN_PROC_BROWSER_TEST_P(ExtensionAIAssistantBrowserTest,
                       MAYBE_TestAssistantFactoryExistence) {
  const ExtensionAIAssistantBrowserTestVariant& variant = GetParam();

  // `self.ai` is expected exactly when the web platform flag is on.
  const bool expect_self_ai = IsPromptAPIForWebPlatformEnabled(variant);

  // `chrome.aiOriginTrial` is expected when the API is switched on (by flag or
  // by origin trial), the permission is requested, and the kill switch is off.
  const bool api_switched_on = IsPromptAPIForExtensionEnabled(variant) ||
                               IsExtensionParticipatingInOriginTrial(variant);
  const bool expect_chrome_ai =
      api_switched_on && IsExtensionPermissionRequested(variant) &&
      !IsPromptAPIForExtensionKillSwitchTriggered(variant);

  auto to_js_literal = [](bool flag) { return flag ? "true" : "false"; };

  TestExtensionDir extension_dir;
  extension_dir.WriteManifest(GetManifest());
  extension_dir.WriteFile(FILE_PATH_LITERAL("sw.js"),
                          base::StringPrintf(kServiceWorkerScript,
                                             to_js_literal(expect_self_ai),
                                             to_js_literal(expect_chrome_ai)));

  ResultCatcher catcher;
  const Extension* loaded = LoadExtension(extension_dir.UnpackedPath());
  ASSERT_TRUE(loaded);
  EXPECT_TRUE(catcher.GetNextResult()) << catcher.message();
}

}  // namespace extensions
